
#include "three_nn_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include <algorithm>

namespace optiling {
// Byte granularity constant (32 * 3 * 4 = 384). TilingFunc below does not use it;
// it is retained in case other code in this namespace refers to it.
const uint32_t BLOCK_SIZE = 32*3*4;

/// Host-side tiling for the ThreeNN operator.
///
/// Reads the batch size and point counts from the two input tensors, splits the
/// per-batch element count of input 1 (xyz2) into fixed-size tiles, and stores the
/// result in the ThreeNNTilingData structure that is serialized into the raw tiling
/// buffer. The operator is launched on a single block and requests no workspace.
///
/// @param context Tiling context supplied by the framework.
/// @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, an input
///         tensor/shape, the tiling buffer, or the workspace array is missing, when an
///         input has fewer than two dimensions, or when the batch dimension is zero.
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const auto* xyz1Tensor = context->GetInputTensor(0);
    const auto* xyz2Tensor = context->GetInputTensor(1);
    const auto* xyz2Shape = context->GetInputShape(1);
    if (xyz1Tensor == nullptr || xyz2Tensor == nullptr || xyz2Shape == nullptr) {
        return ge::GRAPH_FAILED;
    }

    const auto& xyz1Origin = xyz1Tensor->GetOriginShape();
    const auto& xyz2Origin = xyz2Tensor->GetOriginShape();
    // Dimensions 0 and 1 are read below, so both inputs need at least rank 2.
    if (xyz1Origin.GetDimNum() < 2 || xyz2Origin.GetDimNum() < 2) {
        return ge::GRAPH_FAILED;
    }

    const uint32_t shapeB = static_cast<uint32_t>(xyz1Origin.GetDim(0));
    const uint32_t shapeN = static_cast<uint32_t>(xyz1Origin.GetDim(1));
    const uint32_t shapeM = static_cast<uint32_t>(xyz2Origin.GetDim(1));
    // shapeB is used as a divisor below; an empty batch would divide by zero.
    if (shapeB == 0U) {
        return ge::GRAPH_FAILED;
    }

    // Number of xyz2 elements belonging to one batch entry.
    const uint32_t inputNum =
        static_cast<uint32_t>(xyz2Shape->GetStorageShape().GetShapeSize() / shapeB);

    // Elements handled per loop iteration. This is a fixed constant; it is not
    // derived from the unified-buffer capacity of the target platform.
    constexpr uint32_t kTileDataNum = 1440;

    const uint32_t fullTileCount = inputNum / kTileDataNum;
    const uint32_t remainder = inputNum % kTileDataNum;
    // Total loop iterations: one extra iteration when a partial tile remains.
    const uint32_t finalTileNum = (remainder == 0U) ? fullTileCount : fullTileCount + 1U;
    // Elements processed in the last iteration; a full tile when the split is exact.
    const uint32_t tailDataNum = (remainder == 0U) ? kTileDataNum : remainder;

    ThreeNNTilingData tiling;
    // Total element count over all batches, rounded up to a multiple of 32 elements.
    tiling.set_CoreDataNum((inputNum * shapeB + 31) / 32 * 32);
    tiling.set_finalTileNum(finalTileNum);
    tiling.set_tileDataNum(kTileDataNum);
    tiling.set_TailDataNum(tailDataNum);

    tiling.set_shapeB(shapeB);
    tiling.set_shapeN(shapeN);
    tiling.set_shapeM(shapeM);

    context->SetBlockDim(1);

    auto* rawTiling = context->GetRawTilingData();
    if (rawTiling == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTiling->GetData(), rawTiling->GetCapacity());
    rawTiling->SetDataSize(tiling.GetDataSize());

    // One workspace entry is requested and its size is set to zero.
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
}


namespace ge {
/// Shape inference for ThreeNN.
///
/// Both outputs (index 0 "dist" and index 1 "indices") receive the shape of
/// input 0 ("xyz1").
/// NOTE(review): assumes "indices" has the same shape as "dist" — confirm against
/// the kernel implementation.
///
/// @param context Shape-inference context supplied by the framework.
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED when the context, the input shape,
///         or either output shape is unavailable.
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }

    const gert::Shape* x1_shape = context->GetInputShape(0);
    gert::Shape* dist_shape = context->GetOutputShape(0);
    gert::Shape* indices_shape = context->GetOutputShape(1);
    if (x1_shape == nullptr || dist_shape == nullptr || indices_shape == nullptr) {
        return GRAPH_FAILED;
    }

    *dist_shape = *x1_shape;
    // Previously only output 0 was assigned, leaving the second output without a shape.
    *indices_shape = *x1_shape;
    return GRAPH_SUCCESS;
}
}


namespace ops {
class ThreeNN : public OpDef {
public:
    explicit ThreeNN(const char* name) : OpDef(name)
    {
        this->Input("xyz1")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Input("xyz2")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("dist")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("indices")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        this->SetInferShape(ge::InferShape);

        this->AICore()
            .SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend310b");

    }
};

OP_ADD(ThreeNN);
}
